{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## K-Means Algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.linalg import Vectors\n",
    "from pyspark.ml.clustering import KMeans\n",
    "from pyspark.ml.tuning import ParamGridBuilder\n",
    "from pyspark.ml.feature import StringIndexer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Read the Iris dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def mapLibSVM(row): \n",
    "    return (row[5],Vectors.dense(row[:3]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "df = spark.read \\\n",
    "        .format(\"csv\") \\\n",
    "        .option(\"header\", \"true\") \\\n",
    "        .option(\"inferSchema\", \"true\") \\\n",
    "        .load(\"datasets/iris.data\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------+-----------+------------+-----------+-----------+\n",
      "|sepal_length|sepal_width|petal_length|petal_width|      label|\n",
      "+------------+-----------+------------+-----------+-----------+\n",
      "|         5.1|        3.5|         1.4|        0.2|Iris-setosa|\n",
      "|         4.9|        3.0|         1.4|        0.2|Iris-setosa|\n",
      "|         4.7|        3.2|         1.3|        0.2|Iris-setosa|\n",
      "|         4.6|        3.1|         1.5|        0.2|Iris-setosa|\n",
      "|         5.0|        3.6|         1.4|        0.2|Iris-setosa|\n",
      "|         5.4|        3.9|         1.7|        0.4|Iris-setosa|\n",
      "|         4.6|        3.4|         1.4|        0.3|Iris-setosa|\n",
      "|         5.0|        3.4|         1.5|        0.2|Iris-setosa|\n",
      "|         4.4|        2.9|         1.4|        0.2|Iris-setosa|\n",
      "|         4.9|        3.1|         1.5|        0.1|Iris-setosa|\n",
      "|         5.4|        3.7|         1.5|        0.2|Iris-setosa|\n",
      "|         4.8|        3.4|         1.6|        0.2|Iris-setosa|\n",
      "|         4.8|        3.0|         1.4|        0.1|Iris-setosa|\n",
      "|         4.3|        3.0|         1.1|        0.1|Iris-setosa|\n",
      "|         5.8|        4.0|         1.2|        0.2|Iris-setosa|\n",
      "|         5.7|        4.4|         1.5|        0.4|Iris-setosa|\n",
      "|         5.4|        3.9|         1.3|        0.4|Iris-setosa|\n",
      "|         5.1|        3.5|         1.4|        0.3|Iris-setosa|\n",
      "|         5.7|        3.8|         1.7|        0.3|Iris-setosa|\n",
      "|         5.1|        3.8|         1.5|        0.3|Iris-setosa|\n",
      "+------------+-----------+------------+-----------+-----------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Preview the loaded DataFrame (the saved output is cut off at 20 rows).\n",
    "df.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------+-----------+------------+-----------+-----------+----------+\n",
      "|sepal_length|sepal_width|petal_length|petal_width|      label|labelIndex|\n",
      "+------------+-----------+------------+-----------+-----------+----------+\n",
      "|         5.1|        3.5|         1.4|        0.2|Iris-setosa|       0.0|\n",
      "|         4.9|        3.0|         1.4|        0.2|Iris-setosa|       0.0|\n",
      "|         4.7|        3.2|         1.3|        0.2|Iris-setosa|       0.0|\n",
      "|         4.6|        3.1|         1.5|        0.2|Iris-setosa|       0.0|\n",
      "|         5.0|        3.6|         1.4|        0.2|Iris-setosa|       0.0|\n",
      "|         5.4|        3.9|         1.7|        0.4|Iris-setosa|       0.0|\n",
      "|         4.6|        3.4|         1.4|        0.3|Iris-setosa|       0.0|\n",
      "|         5.0|        3.4|         1.5|        0.2|Iris-setosa|       0.0|\n",
      "|         4.4|        2.9|         1.4|        0.2|Iris-setosa|       0.0|\n",
      "|         4.9|        3.1|         1.5|        0.1|Iris-setosa|       0.0|\n",
      "|         5.4|        3.7|         1.5|        0.2|Iris-setosa|       0.0|\n",
      "|         4.8|        3.4|         1.6|        0.2|Iris-setosa|       0.0|\n",
      "|         4.8|        3.0|         1.4|        0.1|Iris-setosa|       0.0|\n",
      "|         4.3|        3.0|         1.1|        0.1|Iris-setosa|       0.0|\n",
      "|         5.8|        4.0|         1.2|        0.2|Iris-setosa|       0.0|\n",
      "|         5.7|        4.4|         1.5|        0.4|Iris-setosa|       0.0|\n",
      "|         5.4|        3.9|         1.3|        0.4|Iris-setosa|       0.0|\n",
      "|         5.1|        3.5|         1.4|        0.3|Iris-setosa|       0.0|\n",
      "|         5.7|        3.8|         1.7|        0.3|Iris-setosa|       0.0|\n",
      "|         5.1|        3.8|         1.5|        0.3|Iris-setosa|       0.0|\n",
      "+------------+-----------+------------+-----------+-----------+----------+\n",
      "only showing top 20 rows\n",
      "\n",
      "+-----+-------------+\n",
      "|label|     features|\n",
      "+-----+-------------+\n",
      "|  0.0|[5.1,3.5,1.4]|\n",
      "|  0.0|[4.9,3.0,1.4]|\n",
      "|  0.0|[4.7,3.2,1.3]|\n",
      "|  0.0|[4.6,3.1,1.5]|\n",
      "|  0.0|[5.0,3.6,1.4]|\n",
      "|  0.0|[5.4,3.9,1.7]|\n",
      "|  0.0|[4.6,3.4,1.4]|\n",
      "|  0.0|[5.0,3.4,1.5]|\n",
      "|  0.0|[4.4,2.9,1.4]|\n",
      "|  0.0|[4.9,3.1,1.5]|\n",
      "|  0.0|[5.4,3.7,1.5]|\n",
      "|  0.0|[4.8,3.4,1.6]|\n",
      "|  0.0|[4.8,3.0,1.4]|\n",
      "|  0.0|[4.3,3.0,1.1]|\n",
      "|  0.0|[5.8,4.0,1.2]|\n",
      "|  0.0|[5.7,4.4,1.5]|\n",
      "|  0.0|[5.4,3.9,1.3]|\n",
      "|  0.0|[5.1,3.5,1.4]|\n",
      "|  0.0|[5.7,3.8,1.7]|\n",
      "|  0.0|[5.1,3.8,1.5]|\n",
      "+-----+-------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "indexer = StringIndexer(inputCol=\"label\", outputCol=\"labelIndex\")\n",
    "indexer = indexer.fit(df).transform(df)\n",
    "indexer.show()\n",
    "df = indexer.rdd.map(mapLibSVM).toDF([\"label\", \"features\"])\n",
    "df.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Trains a k-means model (Estimator).\n",
    "kmeans = KMeans().setK(3).setSeed(3)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Fit the configured KMeans estimator on the (label, features) DataFrame.\n",
    "model = kmeans.fit(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Within Set Sum of Squared Errors = 69.51236666666772\n"
     ]
    }
   ],
   "source": [
    "# Evaluate clustering by computing Within Set Sum of Squared Errors.\n",
    "wssse = model.computeCost(df)\n",
    "print(\"Within Set Sum of Squared Errors = \" + str(wssse))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cluster Centers: \n",
      "[ 5.006  3.418  1.464]\n",
      "[ 5.86833333  2.74        4.38166667]\n",
      "[ 6.8525  3.07    5.6925]\n"
     ]
    }
   ],
   "source": [
    "# Shows the result.\n",
    "centers = model.clusterCenters()\n",
    "print(\"Cluster Centers: \")\n",
    "for center in centers:\n",
    "    print(center)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Apply the fitted model to df; the next cell selects the resulting\n",
    "# \"prediction\" column alongside \"label\".\n",
    "result = model.transform(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------+-----+\n",
      "|prediction|label|\n",
      "+----------+-----+\n",
      "|         0|  0.0|\n",
      "|         0|  0.0|\n",
      "|         0|  0.0|\n",
      "|         0|  0.0|\n",
      "|         0|  0.0|\n",
      "|         0|  0.0|\n",
      "|         0|  0.0|\n",
      "|         0|  0.0|\n",
      "|         0|  0.0|\n",
      "|         0|  0.0|\n",
      "|         0|  0.0|\n",
      "|         0|  0.0|\n",
      "|         0|  0.0|\n",
      "|         0|  0.0|\n",
      "|         0|  0.0|\n",
      "|         0|  0.0|\n",
      "|         0|  0.0|\n",
      "|         0|  0.0|\n",
      "|         0|  0.0|\n",
      "|         0|  0.0|\n",
      "+----------+-----+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "predictions = result.select([\"prediction\",\"label\"])\n",
    "predictions.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Apache Toree - PySpark",
   "language": "python",
   "name": "apache_toree_pyspark"
  },
  "language_info": {
   "codemirror_mode": "text/x-ipython",
   "file_extension": ".py",
   "mimetype": "text/x-ipython",
   "name": "python",
   "pygments_lexer": "python",
   "version": "3.6.2\n"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
